Explaining black box models through counterfactuals

Patrick Altmeyer

Overview

  • Intro
  • Methodological background
  • Examples in CounterfactualExplanations.jl
  • Possible research questions

Introduction

Have you ever …

  • … received an automated rejection email? Why didn’t you “mEet tHe sHoRtLisTiNg cRiTeRia”? 🙃
  • … used deep learning or some other black-box model? 🔮 Could you explain the model behaviour intuitively? 👀
  • … used a black-box model at ING to classify counterparties or clients? 🦁
  • … worked for the Belastingdienst? 🫠

The need for explanations

  • From human to data-driven decision-making:
    • Today, it is more likely than not that your digital loan or employment application will be handled by an algorithm, at least in the first instance.
  • Black-box models create undesirable dynamics:
    • Human operators in charge of the system have to rely on it blindly.
    • Those individuals subject to it generally have no way to challenge an outcome.

“You cannot appeal to (algorithms). They do not listen. Nor do they bend.”

— Cathy O’Neil in Weapons of Math Destruction, 2016

Enter: counterfactual explanations

From 🐱 to 🐶

We have fitted a black-box classifier to separate cats from dogs. One 🐱 is friends with a lot of cool 🐶 and wants to remain part of that group. The counterfactual path below shows her how to fool the classifier:

Limited software availability

  • Some of the existing approaches are scattered across different GitHub repositories (🐍).
  • Only one unifying Python 🐍 library: CARLA (Pawelczyk et al. 2021).
    • Comprehensive and (somewhat) extensible …
    • … but not language-agnostic and some desirable functionality not supported.
  • Both R and Julia lack any kind of implementation. Until now …

Enter: CounterfactualExplanations.jl 📦


  • A unifying framework for generating counterfactual explanations and algorithmic recourse.
  • Built in Julia, but essentially language agnostic: supporting explanations for models built in Python, R, …
  • Designed to be easily extensible through multiple dispatch.
  • Native support for differentiable models built and trained in Julia.

Julia is fast, transparent, beautiful and open 🔴🟢🟣

Explainable AI (XAI)

  • interpretable = inherently interpretable model, no extra tools needed (GLM, decision trees, rules, …) (Rudin 2019)
  • explainable = inherently not interpretable model, but explainable through XAI

Post-hoc explainability:

  • Local surrogate explainers like LIME and Shapley: useful and popular, but …
    • … can be easily fooled (Slack et al. 2020)
    • … rely on reasonably interpretable features.
    • … rely on the concept of fidelity.
  • Counterfactual explanations explain how inputs into a system need to change for it to produce different decisions.
    • Always full-fidelity, since no proxy involved.
    • Intuitive interpretation and straightforward implementation.
    • Works well with Bayesian models. Clear link to Causal Inference.
    • Does not rely on interpretable features.
  • Realistic and actionable changes can be used for the purpose of algorithmic recourse.

Counterfactual Explanations

A framework for counterfactual explanations

  • The objective originally proposed by Wachter, Mittelstadt, and Russell (2017) is as follows, where \(h\) measures the complexity of the counterfactual and \(M\) denotes the classifier:

\[ \min_{x^\prime \in \mathcal{X}} h(x^\prime) \ \ \ \text{s.t.} \ \ \ M(x^\prime) = t \qquad(1)\]

  • Typically approximated through regularization:

\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda h(x^\prime) \qquad(2)\]

So counterfactual search is just gradient descent in the feature space 💡 Easy right?
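To make that concrete, here is a minimal, self-contained sketch of Equation 2 for a hand-rolled logistic classifier, using cross-entropy loss and squared distance to the factual as the complexity penalty \(h\). All names and numbers are illustrative, not the package API:

```julia
# A self-contained sketch of Equation 2 (illustrative, not the package API):
σ(z) = 1 / (1 + exp(-z))

w, b = [1.0, 1.0], 0.0            # toy coefficients
M(x) = σ(w' * x + b)              # "black-box" probability of class 1

# Gradient descent on ℓ(M(x′), t) + λ h(x′): cross-entropy loss plus
# squared distance to the factual as the complexity penalty h.
function counterfactual_search(x₀; t = 1.0, λ = 0.1, η = 0.5, iters = 500)
    x′ = copy(x₀)
    for _ in 1:iters
        ∇ℓ = (M(x′) - t) .* w     # loss gradient through the logit
        ∇h = 2λ .* (x′ - x₀)      # penalty pulls back towards the factual
        x′ -= η .* (∇ℓ + ∇h)
    end
    return x′
end

x₀ = [-2.0, -2.0]                 # factual, confidently classified as 0
x′ = counterfactual_search(x₀)
M(x′) > 0.5                       # true: the counterfactual crosses the boundary
```

The \(\lambda\) trade-off is visible here: the loss pushes \(x^\prime\) across the decision boundary, while the penalty keeps it close to \(x_0\).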

Not so fast …

Effective counterfactuals should meet certain criteria ✅

  • closeness: the average distance between factual and counterfactual features should be small (Wachter, Mittelstadt, and Russell 2017)
  • actionability: the proposed feature perturbation should actually be actionable (Ustun, Spangher, and Liu 2019; Poyiadzi et al. 2020)
  • plausibility: the counterfactual explanation should be plausible to a human (Joshi et al. 2019)
  • unambiguity: a human should have no trouble assigning a label to the counterfactual (Schut et al. 2021)
  • sparsity: the counterfactual explanation should involve as few individual feature changes as possible (Schut et al. 2021)
  • robustness: the counterfactual explanation should be robust to domain and model shifts (Upadhyay, Joshi, and Lakkaraju 2021)
  • diversity: ideally multiple diverse counterfactual explanations should be provided (Mothilal, Sharma, and Tan 2020)
  • causality: counterfactual explanations should reflect the structural causal model underlying the data generating process (Karimi et al. 2020; Karimi, Schölkopf, and Valera 2021)
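Two of these criteria translate directly into simple metrics. A plain-Julia illustration (the function names are ours, not the package's):

```julia
# Closeness as L2 distance, sparsity as the number of changed features:
closeness(x, x′) = sqrt(sum(abs2, x′ - x))
sparsity(x, x′; atol = 1e-6) = count(abs.(x′ .- x) .> atol)

x  = [1.0, 2.0, 3.0]
x′ = [1.0, 2.5, 3.0]
closeness(x, x′)   # 0.5
sparsity(x, x′)    # 1: only the second feature changed
```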

The Bayesian approach - a catchall?

  • Schut et al. (2021) note that different approaches simply use different complexity functions (\(h(x^\prime)\) in Equation 1).
  • They show that for classifiers \(\mathcal{\widetilde{M}}\) that incorporate predictive uncertainty we can drop the complexity penalty altogether:

\[ x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) \ \ , \ \ \forall M\in\mathcal{\widetilde{M}} \qquad(3)\]
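A sketch of what Equation 3 buys us, with a small hand-made ensemble standing in for \(\mathcal{\widetilde{M}}\) (the coefficients are made up purely for illustration):

```julia
# Three posterior draws of (w, b) standing in for the uncertainty-aware
# classifier (hand-picked, illustrative numbers):
σ(z) = 1 / (1 + exp(-z))
ensemble = [([1.0, 1.0], 0.0), ([1.2, 0.8], 0.1), ([0.8, 1.2], -0.1)]

# Average cross-entropy gradient over ensemble members for target t = 1:
∇ℓ_bar(x) = sum((σ(w' * x + b) - 1.0) .* w for (w, b) in ensemble) / length(ensemble)

# Plain gradient descent, no complexity penalty (Equation 3):
descend(x₀; η = 0.5, iters = 500) = foldl((x, _) -> x - η .* ∇ℓ_bar(x), 1:iters; init = x₀)

x′ = descend([-2.0, -2.0])
all(σ(w' * x′ + b) > 0.5 for (w, b) in ensemble)   # every member agrees on the target
```

Because the loss is averaged over the ensemble, the search only settles in regions where all members agree, which is what makes the explicit penalty dispensable.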

CounterfactualExplanations.jl: getting started

Installation

  1. Install Julia.
  2. Install the package:
using Pkg
Pkg.add("CounterfactualExplanations")
  3. Explain your black box 🔮

A simple generic generator

# Data:
using CounterfactualExplanations
using CounterfactualExplanations.Data
using Random
Random.seed!(1234)
N = 25
xs, ys = Data.toy_data_linear(N)
X = hcat(xs...)
counterfactual_data = CounterfactualData(X,ys')

# Model
using CounterfactualExplanations.Models: LogisticModel, probs 
# Logit model:
w = [1.0 1.0] # true coefficients
b = 0
M = LogisticModel(w, [b])

# Randomly selected factual:
Random.seed!(123)
x = select_factual(counterfactual_data,rand(1:size(X)[2]))
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target

# Counterfactual search:
generator = GenericGenerator()
counterfactual = generate_counterfactual(x, target, counterfactual_data, M, generator)

A greedy generator

# Model:
using LinearAlgebra
Σ = Symmetric(reshape(randn(9),3,3).*0.01 + UniformScaling(1)) # MAP covariance matrix
μ = hcat(b, w)
M = CounterfactualExplanations.Models.BayesianLogisticModel(μ, Σ)

# Counterfactual search:
generator = GreedyGenerator(;δ=0.1,n=25)
counterfactual = generate_counterfactual(x, target, counterfactual_data, M, generator)
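The greedy idea of Schut et al. (2021) can be sketched independently of the package: at every step, move only the most influential feature by a fixed amount \(\delta\). All helper names below are ours, plain Julia:

```julia
# Greedy counterfactual search sketch (illustrative, not the package internals):
σ(z) = 1 / (1 + exp(-z))
w, b = [1.0, 2.0], 0.0
M(x) = σ(w' * x + b)
∇ℓ(x) = (M(x) - 1.0) .* w        # cross-entropy gradient for target class 1

function greedy_step(x; δ = 0.1)
    g = ∇ℓ(x)
    i = argmax(abs.(g))           # pick the most influential feature …
    x′ = copy(x)
    x′[i] -= δ * sign(g[i])       # … and move it by a fixed amount δ
    return x′
end

greedy(x; δ = 0.1, n = 25) = foldl((x, _) -> greedy_step(x; δ = δ), 1:n; init = x)

x′ = greedy([-1.0, -1.0])
M(x′) > 0.5                       # true: the fixed-size steps cross the boundary
```

Fixed-size, one-feature-at-a-time steps are what make this generator's counterfactuals sparse.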

Custom models and interoperability

Subtyping and dispatch

using Flux, RCall
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend

# Step 1)
struct TorchNetwork <: Models.AbstractFittedModel
    nn::Any
end

# Step 2)
function logits(M::TorchNetwork, X::AbstractArray)
  nn = M.nn
  ŷ = rcopy(R"as_array($nn(torch_tensor(t($X))))")
  ŷ = isa(ŷ, AbstractArray) ? ŷ : [ŷ]
  return ŷ'
end
probs(M::TorchNetwork, X::AbstractArray) = σ.(logits(M, X))
M = TorchNetwork(R"model")

Gradient access

import CounterfactualExplanations.Generators: ∂ℓ
using LinearAlgebra

# Counterfactual loss:
function ∂ℓ(generator::AbstractGradientBasedGenerator, counterfactual_state::CounterfactualState) 
  M = counterfactual_state.M
  nn = M.nn
  x′ = counterfactual_state.x′
  t = counterfactual_state.target_encoded
  R"""
  x <- torch_tensor($x′, requires_grad=TRUE)
  output <- $nn(x)
  obj_loss <- nnf_binary_cross_entropy_with_logits(output,$t)
  obj_loss$backward()
  """
  grad = rcopy(R"as_array(x$grad)")
  return grad
end

Subtyping and dispatch

using Flux, PyCall
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend

# Step 1)
struct PyTorchNetwork <: Models.AbstractFittedModel
    nn::Any
end

# Step 2)
function logits(M::PyTorchNetwork, X::AbstractArray)
  nn = M.nn
  if !isa(X, Matrix)
    X = reshape(X, length(X), 1)
  end
  ŷ = py"$nn(torch.Tensor($X).T).detach().numpy()"
  ŷ = isa(ŷ, AbstractArray) ? ŷ : [ŷ]
  return ŷ'
end
probs(M::PyTorchNetwork, X::AbstractArray) = σ.(logits(M, X))
M = PyTorchNetwork(py"model")

Gradient access

import CounterfactualExplanations.Generators: ∂ℓ
using LinearAlgebra

# Counterfactual loss:
function ∂ℓ(generator::AbstractGradientBasedGenerator, counterfactual_state::CounterfactualState) 
  M = counterfactual_state.M
  nn = M.nn
  x′ = counterfactual_state.x′
  t = counterfactual_state.target_encoded
  x = reshape(x′, 1, length(x′))
  py"""
  x = torch.Tensor($x)
  x.requires_grad = True
  t = torch.Tensor($[t]).squeeze()
  output = $nn(x).squeeze()
  obj_loss = nn.BCEWithLogitsLoss()(output,t)
  obj_loss.backward()
  """
  grad = vec(py"x.grad.detach().numpy()")
  return grad
end

Custom generators

Subtyping

# Abstract subtype:
abstract type AbstractDropoutGenerator <: AbstractGradientBasedGenerator end

# Constructor:
struct DropoutGenerator <: AbstractDropoutGenerator
    loss::Symbol # loss function
    complexity::Function # complexity function
    mutability::Union{Nothing,Vector{Symbol}} # mutability constraints 
    λ::AbstractFloat # strength of penalty
    ϵ::AbstractFloat # step size
    τ::AbstractFloat # tolerance for convergence
    p_dropout::AbstractFloat # dropout rate
end

# Instantiate:
using LinearAlgebra
generator = DropoutGenerator(
    :logitbinarycrossentropy,
    norm,
    nothing,
    0.1,
    0.1,
    1e-5,
    0.5
)

Dispatch

import CounterfactualExplanations.Generators: generate_perturbations, ∇
using StatsBase
function generate_perturbations(generator::AbstractDropoutGenerator, counterfactual_state::CounterfactualState)
    𝐠ₜ = ∇(generator, counterfactual_state) # gradient
    # Dropout:
    set_to_zero = sample(1:length(𝐠ₜ),Int(round(generator.p_dropout*length(𝐠ₜ))),replace=false)
    𝐠ₜ[set_to_zero] .= 0
    Δx′ = - (generator.ϵ .* 𝐠ₜ) # gradient step
    return Δx′
end

Feature constraints

Mutability constraints can be added at the preprocessing stage:

counterfactual_data = CounterfactualData(X,ys';domain=[(-Inf,Inf),(-Inf,-0.5)])
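One way to picture the effect of that domain on the second feature: every candidate counterfactual is kept inside its per-feature bounds. A plain-Julia sketch (apply_domain is an illustrative helper, not the package API):

```julia
# Per-feature bounds as (lower, upper) tuples, matching the domain above:
domain = [(-Inf, Inf), (-Inf, -0.5)]

# Clamp each feature of a candidate counterfactual into its bounds:
apply_domain(x′, domain) = [clamp(x′[i], domain[i]...) for i in eachindex(x′)]

apply_domain([0.3, 0.2], domain)   # second feature is pulled down to -0.5
```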

Application to MNIST

Counterfactuals for image data

This looks nice 🤓

And this … ugh 🥴

Discussion and outlook

The package 📦

Research topics (1) - student project

What happens once AR has actually been implemented? 👀

Research topics (2)

  • An effortless way to incorporate model uncertainty (w/o need for expensive generative model): Laplace Redux.
  • Counterfactual explanations for time series data.
  • Is CE really more intuitive? Could run a user-based study like in Kaur et al. (2020).
  • More ideas from your side? 🤗

More resources

References

Joshi, Shalmali, Oluwasanmi Koyejo, Warut Vijitbenjaronk, Been Kim, and Joydeep Ghosh. 2019. “Towards Realistic Individual Recourse and Actionable Explanations in Black-Box Decision Making Systems.” arXiv Preprint arXiv:1907.09615.
Karimi, Amir-Hossein, Bernhard Schölkopf, and Isabel Valera. 2021. “Algorithmic Recourse: From Counterfactual Explanations to Interventions.” In Proceedings of the 2021 ACM Conference on Fairness, Accountability, and Transparency, 353–62.
Karimi, Amir-Hossein, Julius Von Kügelgen, Bernhard Schölkopf, and Isabel Valera. 2020. “Algorithmic Recourse Under Imperfect Causal Knowledge: A Probabilistic Approach.” arXiv Preprint arXiv:2006.06831.
Kaur, Harmanpreet, Harsha Nori, Samuel Jenkins, Rich Caruana, Hanna Wallach, and Jennifer Wortman Vaughan. 2020. “Interpreting Interpretability: Understanding Data Scientists’ Use of Interpretability Tools for Machine Learning.” In Proceedings of the 2020 CHI Conference on Human Factors in Computing Systems, 1–14.
Mothilal, Ramaravind K, Amit Sharma, and Chenhao Tan. 2020. “Explaining Machine Learning Classifiers Through Diverse Counterfactual Explanations.” In Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency, 607–17.
Pawelczyk, Martin, Sascha Bielawski, Johannes van den Heuvel, Tobias Richter, and Gjergji Kasneci. 2021. “Carla: A Python Library to Benchmark Algorithmic Recourse and Counterfactual Explanation Algorithms.” arXiv Preprint arXiv:2108.00783.
Poyiadzi, Rafael, Kacper Sokol, Raul Santos-Rodriguez, Tijl De Bie, and Peter Flach. 2020. “FACE: Feasible and Actionable Counterfactual Explanations.” In Proceedings of the AAAI/ACM Conference on AI, Ethics, and Society, 344–50.
Rudin, Cynthia. 2019. “Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead.” Nature Machine Intelligence 1 (5): 206–15.
Schut, Lisa, Oscar Key, Rory Mc Grath, Luca Costabello, Bogdan Sacaleanu, Yarin Gal, et al. 2021. “Generating Interpretable Counterfactual Explanations by Implicit Minimisation of Epistemic and Aleatoric Uncertainties.” In International Conference on Artificial Intelligence and Statistics, 1756–64. PMLR.
Slack, Dylan, Sophie Hilgard, Emily Jia, Sameer Singh, and Himabindu Lakkaraju. 2020. “Fooling Lime and Shap: Adversarial Attacks on Post Hoc Explanation Methods.” In Proceedings of the AAAI/ACM Conference on AI, Ethics, and Society, 180–86.
Upadhyay, Sohini, Shalmali Joshi, and Himabindu Lakkaraju. 2021. “Towards Robust and Reliable Algorithmic Recourse.” arXiv Preprint arXiv:2102.13620.
Ustun, Berk, Alexander Spangher, and Yang Liu. 2019. “Actionable Recourse in Linear Classification.” In Proceedings of the Conference on Fairness, Accountability, and Transparency, 10–19.
Wachter, Sandra, Brent Mittelstadt, and Chris Russell. 2017. “Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR.” Harv. JL & Tech. 31: 841.